import torch
import torch.nn.functional as F

import numpy as np
import scipy.sparse as sp
import os
from torch_geometric.utils import negative_sampling
from sklearn.neighbors import kneighbors_graph


# %%
class PairwiseAttrSimTask(torch.nn.Module):
    """Pretext task: classify whether two nodes are neighbours in an attribute kNN graph.

    For every dataset in ``data.datasets`` a k-nearest-neighbour graph is built
    over the node features (cosine metric, each node counted as its own
    neighbour). Pairs connected in that graph are positives (label 1); pairs
    drawn by ``negative_sampling`` from outside it are negatives (label 0). A
    linear head classifies the element-wise absolute difference of the two
    node embeddings.

    Each dataset's kNN adjacency is cached under ``saved/`` and reused on
    subsequent constructions.
    """

    num_classes = 2
    num_knn_neighbours = 10
    # Positive pairs requested per call; clamped to the number of kNN edges.
    num_samples = 4000

    def __init__(self, data, embedding_size, device):
        """Build the predictor head and the per-dataset kNN edge indices.

        Args:
            data: container exposing ``datasets`` (each with ``name`` and a
                ``data`` attribute holding ``num_nodes`` and ``x``) and, when
                features are 3-D, ``pos_enc_dim``.
            embedding_size: width of the node embeddings given to ``get_loss``.
            device: device for the predictor head and for sampled pairs/labels.
        """
        super().__init__()
        self.name = 'pair-wise attribute similarity'
        self.data = data
        self.dataset_names = [dataset.name for dataset in data.datasets]
        self.device = device
        self.num_nodes = [dataset.data.num_nodes for dataset in data.datasets]
        self.predictor = torch.nn.Linear(embedding_size, self.num_classes).to(device)
        self.edge_index_knn = []
        self.build_knn()

    def sample(self, dataset_name):
        """Draw a balanced batch of positive and negative node pairs.

        Args:
            dataset_name: name of one of the datasets given at construction.

        Returns:
            ``(edges, labels)`` where ``edges`` is a ``(2, P + Q)`` long tensor of
            node-index pairs (the P kNN edges first, then the Q sampled
            non-edges) and ``labels`` is a length ``P + Q`` long tensor holding 1
            for positives and 0 for negatives. Both live on ``self.device``.

        Raises:
            ValueError: if ``dataset_name`` is not a known dataset.
        """
        index = self.dataset_names.index(dataset_name)
        edge_index = self.edge_index_knn[index]

        pos_edges = self._sample_positive_edges(edge_index, self.num_samples)
        # Ask for as many negatives as positives to keep the classes balanced;
        # negative_sampling may return fewer, so labels use its actual size.
        neg_edges = negative_sampling(edge_index=edge_index, num_nodes=self.num_nodes[index],
                                      num_neg_samples=pos_edges.shape[1])

        edges = torch.cat([pos_edges, neg_edges], dim=1).to(self.device)
        labels = torch.cat([
            torch.ones(pos_edges.shape[1], dtype=torch.long),
            torch.zeros(neg_edges.shape[1], dtype=torch.long),
        ]).to(self.device)
        return edges, labels

    def get_loss(self, embeddings, dataset_name):
        """Return the NLL loss of the similarity head on a freshly sampled batch.

        Args:
            embeddings: node embeddings for ``dataset_name``, indexed by node id
                along dim 0 with last dim ``embedding_size`` (assumed to be on
                ``self.device`` — confirm with caller).
            dataset_name: which dataset's kNN graph to draw pairs from.

        Returns:
            Scalar loss tensor.
        """
        node_pairs, labels = self.sample(dataset_name)
        diff = torch.abs(embeddings[node_pairs[0]] - embeddings[node_pairs[1]])
        logits = self.predictor(diff)
        return F.nll_loss(F.log_softmax(logits, dim=1), labels)

    def build_knn(self):
        """Append one kNN edge index per dataset to ``self.edge_index_knn``.

        The adjacency is loaded from ``saved/<name>_knn_<k>.npz`` when that file
        exists; otherwise it is computed with ``kneighbors_graph`` (cosine
        metric, self included) and written to that path, creating ``saved/``
        if needed.
        """
        for dataset in self.data.datasets:
            path = self._knn_cache_path(dataset.name)
            if os.path.exists(path):
                print(f'loading {path}')
                a_knn = sp.load_npz(path)
            else:
                x = dataset.data.x.to(torch.device('cpu'))
                if x.ndim > 2:
                    x = self._strip_pos_enc(x, self.data.pos_enc_dim)
                a_knn = kneighbors_graph(x, self.num_knn_neighbours, mode='connectivity', metric='cosine',
                                         include_self=True, n_jobs=4)
                print(f'saving {path}')
                # The cache directory may not exist on a fresh checkout.
                os.makedirs(os.path.dirname(path), exist_ok=True)
                sp.save_npz(path, a_knn)
            self.edge_index_knn.append(self._adjacency_to_edge_index(a_knn))

    def _knn_cache_path(self, dataset_name):
        """Return the on-disk cache path for a dataset's kNN adjacency."""
        return f'saved/{dataset_name}_knn_{self.num_knn_neighbours}.npz'

    @staticmethod
    def _strip_pos_enc(x, pos_enc_dim):
        """Take slot 0 along dim 1 and drop the trailing ``pos_enc_dim`` feature columns.

        The dropped columns are presumably positional encodings appended to the
        raw attributes (inferred from the name — confirm). Slicing with
        ``:x.shape[-1] - pos_enc_dim`` instead of ``:-pos_enc_dim`` keeps every
        column when ``pos_enc_dim`` is 0 (``:-0`` would select nothing).
        """
        return x[:, 0, :x.shape[-1] - pos_enc_dim]

    @staticmethod
    def _adjacency_to_edge_index(adj):
        """Convert a scipy sparse matrix into a ``(2, E)`` int64 tensor of its nonzero coordinates."""
        rows, cols = adj.nonzero()
        return torch.from_numpy(np.vstack((rows, cols)).astype(np.int64))

    @staticmethod
    def _sample_positive_edges(edge_index, num_samples):
        """Return up to ``num_samples`` distinct columns of ``edge_index``, chosen uniformly.

        The count is clamped to the number of available edges so small graphs
        do not fail; ``torch.randperm`` yields int64 indices and honours the
        torch RNG seed.
        """
        num_edges = edge_index.shape[1]
        count = min(num_samples, num_edges)
        chosen = torch.randperm(num_edges, device=edge_index.device)[:count]
        return edge_index[:, chosen]
